{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as torch_models\n",
    "from torch import nn\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Special_Adapter_v1(nn.Module):\n",
    "    def __init__(self, in_planes:int, mid_planes:int, kernel_size:int, use_alpha=True, conv_group=1):\n",
    "        super().__init__()\n",
    "        self.in_planes = in_planes\n",
    "        self.mid_planes = mid_planes\n",
    "        self.conv = nn.Conv2d(in_planes, mid_planes, kernel_size=kernel_size, groups=conv_group)\n",
    "        self.bn1 = nn.BatchNorm2d(mid_planes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.convTransposed = nn.ConvTranspose2d(mid_planes, in_planes, kernel_size=kernel_size, groups=conv_group)\n",
    "        self.bn2 = nn.BatchNorm2d(in_planes)\n",
    "        \n",
    "        self.use_alpha = use_alpha\n",
    "        if use_alpha:\n",
    "            self.alpha = nn.Parameter(torch.ones(1)*0.02)\n",
    "            print('Apply alpha!')\n",
    "    \n",
    "    def forward(self, x):\n",
    "        if isinstance(x, tuple):\n",
    "            x = x[0]\n",
    "        \n",
    "        ### original: conv+bn+ReLU+convT+bn+ReLU ###\n",
    "        out = self.conv(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.convTransposed(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu(out)\n",
    "        \n",
    "        if self.use_alpha:\n",
    "            out = out * self.alpha\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_extractor = torch_models.resnet18()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1\n",
      "bn1\n",
      "relu\n",
      "maxpool\n",
      "layer1\n",
      "layer2\n",
      "layer3\n",
      "layer4\n",
      "avgpool\n",
      "fc\n"
     ]
    }
   ],
   "source": [
    "for name, module in feature_extractor.named_children():\n",
    "    print(f'{name}')    \n",
    "    # x = module(x)\n",
    "    # if name in self.layer_names:\n",
    "    #     adapter_id = name.replace('.', '_') + '_adapters'\n",
    "    #     adapters = getattr(self, adapter_id)\n",
    "    #     for adapter in adapters:\n",
    "    #         x = adapter(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN_Adapter_Net_CIL_V2(nn.Module):\n",
    "    def __init__(self,fe):\n",
    "        super(CNN_Adapter_Net_CIL_V2, self).__init__()\n",
    "        self.feature_extractor = fe\n",
    "        self.feature_extractor.fc = nn.Identity()\n",
    "        self.layer_names = ['layer1','layer2','layer3','layer4']\n",
    "        for layer_id in self.layer_names:\n",
    "            adapter_id = layer_id.replace('.', '_')+'_adapters'\n",
    "            self.register_module(adapter_id, nn.ModuleList([]))\n",
    "        self.task_sizes = [2]\n",
    "    def forward(self,x):\n",
    "        \n",
    "        for name, module in self.feature_extractor.named_children():\n",
    "            print(f'{name}')    \n",
    "            if name in self.layer_names:\n",
    "                adapter_id = name.replace('.', '_') + '_adapters'\n",
    "                adapters = getattr(self, adapter_id)\n",
    "                b, c, h, w = x.shape\n",
    "                if len(adapters) < len(self.task_sizes):\n",
    "                    print(f'Append new adapters')\n",
    "                    getattr(self, adapter_id).append(Special_Adapter_v1(c, c, 3).cuda())\n",
    "                \n",
    "                x = x + getattr(self, adapter_id)[-1](x)\n",
    "                \n",
    "            x = module(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = CNN_Adapter_Net_CIL_V2(feature_extractor).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3, 224, 224])"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(1, 3, 224, 224).cuda()\n",
    "x.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1\n",
      "bn1\n",
      "relu\n",
      "maxpool\n",
      "layer1\n",
      "Append new adapters\n",
      "Apply alpha!\n",
      "layer2\n",
      "Append new adapters\n",
      "Apply alpha!\n",
      "layer3\n",
      "Append new adapters\n",
      "Apply alpha!\n",
      "layer4\n",
      "Append new adapters\n",
      "Apply alpha!\n",
      "avgpool\n",
      "fc\n"
     ]
    }
   ],
   "source": [
    "res = net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 512, 1, 1])"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModuleList(\n",
       "  (0): Special_Adapter_v1(\n",
       "    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu): ReLU(inplace=True)\n",
       "    (convTransposed): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.layer1_adapters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CL_Pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
